[Draft] [KDA] sm100 GVA enhance #65

Open
sjmshsh wants to merge 2 commits into inclusionAI:main from sjmshsh:feat/kda-sm100-gva

Conversation


sjmshsh commented on May 7, 2026

Summary

Extend SM100 KDA forward to support num_v_heads > num_qk_heads (GVA), following the pattern already established by the SM90 KDA and by gated_delta_rule GVA.

Branch: feat/kda-sm100-gva, single commit e0e3494.

What's changed

Scheduler & config

  • tile_scheduler.hpp – Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx = v_head_idx / heads_per_group (see the sketch after this list). When HV == HQK, heads_per_group == 1 and behaviour is unchanged.
  • kda_config.hpp – KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group. The Akk and w/u/kg/qg layouts now live in v-head space.
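
A minimal sketch of the head mapping described above. The names (TileCoordSketch, decode_tile_coord_sketch) are illustrative and do not match tile_scheduler.hpp exactly; the real scheduler enumerates tiles through its own Params:

    // Hypothetical illustration of the v-head -> qk-head mapping (not the
    // actual tile_scheduler.hpp code).
    struct TileCoordSketch {
        int tile_idx;      // chunk/tile index along the sequence
        int v_head_idx;    // head index in v-head space (0 .. h_v - 1)
        int qk_head_idx;   // derived q/k head index (0 .. h_qk - 1)
    };

    inline TileCoordSketch decode_tile_coord_sketch(int linear_idx, int num_tiles,
                                                    int heads_per_group /* h_v / h_qk */) {
        TileCoordSketch c;
        c.tile_idx    = linear_idx % num_tiles;          // tiles enumerated in v-head space
        c.v_head_idx  = linear_idx / num_tiles;
        c.qk_head_idx = c.v_head_idx / heads_per_group;  // a group of v heads shares one q/k head
        return c;
    }
    // With h_v == h_qk, heads_per_group == 1 and qk_head_idx == v_head_idx,
    // which is the pre-GVA behaviour.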

Intra kernel / mainloop

  • Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v).
  • Load warp slices Q/K with qk_head_idx and g with the v-head index.
  • Aqk row stride and beta stride now use params.h_v (see the addressing sketch below).
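
For illustration only, assuming a d-innermost [total, h, d] storage for Q/K/g and a per-(token, v-head) beta; the actual mainloop goes through TMA descriptors and CuTe layouts rather than raw pointers:

    // Hypothetical raw-pointer view of the per-head slicing described above.
    #include <cstddef>

    struct HeadSlicesSketch {
        const float* q_head;   // float is a stand-in for the kernel's element type
        const float* g_head;
        float        beta_th;
    };

    inline HeadSlicesSketch slice_heads_sketch(const float* Q, const float* G, const float* beta,
                                               std::size_t token, int qk_head_idx, int v_head_idx,
                                               int h_qk, int h_v, int d) {
        HeadSlicesSketch s;
        // Q/K are [total, h_qk, d] with d contiguous -> sliced with qk_head_idx
        s.q_head  = Q + token * (std::size_t)h_qk * d + (std::size_t)qk_head_idx * d;
        // g is [total, h_v, d] -> sliced with the v-head index
        s.g_head  = G + token * (std::size_t)h_v * d + (std::size_t)v_head_idx * d;
        // beta is stored per (token, v head), so its per-token stride is h_v
        s.beta_th = beta[token * (std::size_t)h_v + v_head_idx];
        return s;
    }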

Recomp W/U kernel / mainloop

  • K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v).
  • Load warp slices K/Q with qk_head_idx and V/g/Akk with the v-head index.
  • w/u/kg/qg write stride and beta stride now use params.h_v (see the shape sketch below).
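
The three shape triples can be pictured as below; the field names are hypothetical and only mirror the description, not the actual kda_config.hpp structs:

    // Hypothetical shapes mirroring the recomp_w_u description above.
    #include <cstdint>

    struct FwdRecompShapesSketch {
        int64_t shape_QK[3];   // (total, d,  h_qk) - shared by the K and Q TMA descriptors
        int64_t shape_VG[3];   // (total, d,  h_v)  - shared by the V and g TMA descriptors
        int64_t shape_Akk[3];  // (total, BT, h_v)  - Akk already lives in v-head space
    };

    // Example with hypothetical sizes total = B*T, d = 128, BT = 64, h_qk = 8, h_v = 16:
    //   shape_QK  = {total, 128,  8}
    //   shape_VG  = {total, 128, 16}
    //   shape_Akk = {total,  64, 16}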

API / Python

  • csrc/api/kda_sm100.cu – derive h_qk from Q/K and h_v from V/g; assert HV % HQK == 0 and add beta/qg_out shape checks (a validation sketch follows this list).
  • cula/kda/chunk_intra.py – infer HQK = k.shape[2] and HV = v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.
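
A minimal sketch of the kind of head-count validation described for kda_sm100.cu, assuming a [..., h, d] dimension order; the helper name, the exact checks, and the error text are illustrative and may differ from the file:

    // Hypothetical validation helper (not the actual kda_sm100.cu code).
    #include <torch/extension.h>

    void check_gva_heads_sketch(const at::Tensor& q, const at::Tensor& k,
                                const at::Tensor& v, const at::Tensor& g,
                                const at::Tensor& beta) {
        const int64_t h_qk = q.size(-2);  // assuming [..., h_qk, d]
        const int64_t h_v  = v.size(-2);  // assuming [..., h_v, d]
        TORCH_CHECK(k.size(-2) == h_qk, "q and k must have the same head count");
        TORCH_CHECK(g.size(-2) == h_v,  "v and g must have the same head count");
        TORCH_CHECK(h_qk > 0 && h_v % h_qk == 0,
                    "num_v_heads (", h_v, ") must be a multiple of num_qk_heads (", h_qk, ")");
        TORCH_CHECK(beta.size(-1) == h_v, "beta must be laid out per v head");
    }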

Backward compatibility

When HV == HQK:

  • heads_per_group == 1
  • qk_head_idx == v_head_idx
  • shape_QK == shape_VG
  • All strides and shapes reduce to the pre-GVA layout.

No existing HV == HQK workloads should observe any behavioural change.
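
As a sanity check of this degenerate case, the scheduler sketch above collapses to an identity head mapping when HV == HQK (reusing the hypothetical decode_tile_coord_sketch from that sketch):

    // Tiny check that heads_per_group == 1 yields qk_head_idx == v_head_idx.
    #include <cassert>

    int main() {
        const int h_qk = 8, h_v = 8, num_tiles = 4;
        const int heads_per_group = h_v / h_qk;   // == 1
        for (int i = 0; i < num_tiles * h_v; ++i) {
            TileCoordSketch c = decode_tile_coord_sketch(i, num_tiles, heads_per_group);
            assert(c.qk_head_idx == c.v_head_idx); // identity mapping when HV == HQK
        }
        return 0;
    }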

Known follow-ups (not part of this PR)

The end-to-end SM100 path (chunk_kda_fwd in cula/kda/chunk_fwd.py) feeds the intra/recomp outputs into chunk_gated_delta_rule_fwd_h and chunk_gla_fwd_o, which currently assume q/v/g/A/o share the same head count. This PR intentionally leaves those two CuTe kernels untouched (SM90 does not go through them, so mirroring SM90 leaves them out of scope). A follow-up PR is needed to teach those two kernels GVA before the full SM100 pipeline runs with HV > HQK.

Testing

  • Local compile (pending SM100 toolchain access).
  • Numerical parity vs SM90 KDA under HV == HQK.
  • Numerical correctness under HV > HQK after the downstream CuTe kernels are GVA-ready.

Draft because the downstream GVA-enablement work and end-to-end validation are still pending.

#55


gemini-code-assist (bot) left a comment


Code Review

This pull request implements Grouped V-head Attention (GVA) support across the KDA kernels for both SM90 and SM100 architectures. Key changes include decoupling head counts for Q/K and V/G tensors, updating TMA descriptors and tile scheduling logic to handle these grouped configurations, and adding comprehensive validation checks. The Python API and test suite have been updated to support and verify GVA functionality. Feedback from the review identifies a documentation mismatch regarding tensor layouts in the SM100 mainloop and suggests correcting terminology in Python error messages to distinguish between head count and head dimension.

int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16);

// GMEM output address: layout [total_len, d, h], stride [d*h, 1, d]
// GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d]

medium

The comment mentions layout [total_len, d, h_v], but the stride [d*h_v, 1, d] and the code logic actually correspond to a [total_len, h_v, d] layout (where d is the inner-most dimension).

Suggested change
                // GMEM output address: layout [total_len, h_v, d], stride [d*h_v, 1, d]
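
For reference, a minimal sketch of the addressing the reviewer describes, with d as the inner-most (contiguous) dimension; illustration only, not the kernel's actual code:

    // Offset of element (token t, head, dim x) in a d-innermost layout:
    // per-dimension strides are d*h_v for the token, d for the head and 1 for x.
    #include <cstddef>

    inline std::size_t gmem_offset_sketch(std::size_t t, std::size_t head, std::size_t x,
                                          std::size_t h_v, std::size_t d) {
        return t * h_v * d + head * d + x;
    }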

Comment thread: cula/kda/chunk_intra.py
f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}"
)
assert HV > 0 and HQK > 0 and HV % HQK == 0, (
f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"

medium

The error message incorrectly uses the term 'head-dim' when referring to HV and HQK, which represent the number of heads (head count). The head dimension is represented by K.

Suggested change
f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"
f"v head count (HV={HV}) must be a positive multiple of k head count (HQK={HQK})"

Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads.

C++ changes:

- tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour.

- kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space.

- intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v.

- recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v.

API / Python:

- kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes.

- cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.

Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout.
sjmshsh force-pushed the feat/kda-sm100-gva branch from e0e3494 to 58535e2 on May 7, 2026 at 03:02
sjmshsh changed the title from [KDA] sm100 GVA enhance to [Draft] [KDA] sm100 GVA enhance on May 7, 2026